from datetime import datetime
import time
import math
import os
import subprocess
from faster_whisper import WhisperModel
from zhconv import convert
import argparse
import ffmpeg

# Command-line configuration shared (as the module-level ``args``) by every
# function below.
parser = argparse.ArgumentParser()
parser.add_argument('--model_size', type=str, default='large-v2',
                    help='Whisper model name, or a folder under ./model/ with that name')
parser.add_argument('--language', type=str, default=None,
                    help='transcription language code; None lets the model detect it')
parser.add_argument('--device', type=str, default='cuda',
                    help='device passed to WhisperModel')
parser.add_argument('--compute_type', type=str, default='float16',
                    help='compute type passed to WhisperModel')
parser.add_argument('--input_path', type=str, default='input',
                    help='folder containing the media files to process')
# os.path.join instead of a hard-coded '\\' so the default is a real
# "output" sub-folder on every OS (identical result on Windows).
parser.add_argument('--output_path', type=str, default=os.path.join(os.getcwd(), 'output'),
                    help='root folder for generated subtitles and videos')
args = parser.parse_args()


def create_directory(path):
    """Ensure a directory exists at *path* and report what happened.

    If nothing exists at *path*, the directory (and any missing parents) is
    created and "新建{path}文件夹" is printed; otherwise "已有{path}文件夹" is
    printed and nothing is created.
    """
    if os.path.exists(path):
        print(f"已有{path}文件夹")
        return
    os.makedirs(path)
    print(f"新建{path}文件夹")


def environment_initializing():
    """Create this run's output folders and load the Whisper model.

    Creates ``<args.output_path>/<timestamp>`` plus its ``subtitle`` and
    ``video`` sub-folders, where the timestamp is formatted as
    ``%Y-%m-%d-%H-%M-%S``. The model is loaded from ``model/<model_size>``
    when that path exists; otherwise ``args.model_size`` is passed to
    ``WhisperModel`` unchanged.

    Returns:
        tuple: ``(now_time, model)``, i.e. the timestamp string naming the
        run folder and the constructed ``WhisperModel``.
    """
    now_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    # os.path.join instead of hard-coded '\\' so the tree is correct on any OS.
    run_dir = os.path.join(args.output_path, now_time)
    create_directory(run_dir)
    create_directory(os.path.join(run_dir, 'subtitle'))
    create_directory(os.path.join(run_dir, 'video'))

    # Prefer a locally downloaded model folder over the named model.
    local_model = os.path.join('model', args.model_size)
    model_source = local_model if os.path.exists(local_model) else args.model_size
    model = WhisperModel(model_source, device=args.device, compute_type=args.compute_type)
    return now_time, model


def video2audio(src, dst=os.path.join('tmp', '.mp3')):
    """Extract the audio of *src* into an MP3 file for transcription.

    ffmpeg drops the video (``-vn``), encodes with libmp3lame, resamples to
    16 kHz (``-ar 16000``) and overwrites *dst* (``-y``).

    Args:
        src: path of the input media file.
        dst: output MP3 path. Defaults to ``tmp/.mp3``, the file the main
            loop passes to ``audio2sentences``.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero code.
    """
    # ffmpeg does not create missing output folders, and nothing else in
    # this script creates tmp/, so make sure it exists.
    out_dir = os.path.dirname(dst)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # An argument list (not one command string): a plain string without
    # shell=True only works on Windows, and a list needs no quoting for file
    # names containing spaces or quotes.
    command = ['ffmpeg', '-i', src, '-vn', '-c:a', 'libmp3lame', '-ar', '16000', '-y', dst]
    subprocess.run(command, capture_output=True, check=True)
    # check=True raises on failure, so reaching here means ffmpeg succeeded.
    print("音频分离成功")


def _split_words_into_sentences(words, delimiters=(',',)):
    """Group one segment's word objects into sentence dicts ending at a delimiter.

    Each word's text is converted to Simplified Chinese with ``convert``. A
    sentence closes on a word whose last character is in *delimiters*, and
    that trailing character is removed from the text. Words after the last
    delimiter form a final sentence. Sentences whose text is empty or
    whitespace-only (e.g. two delimiters in a row) are dropped.

    Returns:
        list[dict]: each dict has ``'text'``, ``'start'`` and ``'end'``,
        taken from the first and last word's ``start``/``end`` attributes.
    """
    sentences = []
    current = None
    for w in words:
        if current is None:
            current = {'text': '', 'start': w.start}
        word = convert(w.word, 'zh-hans')
        current['text'] += word
        current['end'] = w.end
        # Guard against an empty word string before indexing its last char.
        if word and word[-1] in delimiters:
            current['text'] = current['text'][:-1]
            if current['text'].strip():
                sentences.append(current)
            current = None
    # Trailing words with no closing delimiter still form a sentence.
    if current is not None and current['text'].strip():
        sentences.append(current)
    return sentences


def audio2sentences(src, delimiters=(',',)):
    """Transcribe *src* and append each sentence as a cue to the SRT file.

    Uses the module-level ``model`` with word timestamps. Each segment's
    words are split into sentences at words ending in a *delimiters*
    character, and every sentence is written through ``sentence2srt``. Cue
    numbers start at 1, as the SRT format expects.

    Args:
        src: path of the audio/media file to transcribe.
        delimiters: characters that end a sentence; the default ``(',',)``
            matches the original ASCII-comma behaviour.
    """
    # append_punctuations='' presumably keeps trailing punctuation from being
    # merged into the preceding word, so commas arrive as their own tokens.
    # Confirm against faster-whisper's transcribe() docs.
    segments, _info = model.transcribe(src, language=args.language, word_timestamps=True, append_punctuations='')
    cue_number = 1
    for segment in segments:
        # A segment with no word timings would crash on words[0]; skip it.
        if not segment.words:
            continue
        print(segment.words[0].start, segment.words[-1].end)
        print(segment.text)
        for sentence in _split_words_into_sentences(segment.words, delimiters):
            print(sentence)
            sentence2srt(sentence, cue_number)
            cue_number += 1
        print('--------------------')


def _format_srt_timestamp(seconds):
    """Format an offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``."""
    # Round to whole milliseconds first. Splitting the float with modf and
    # truncating loses a millisecond to binary error (2.3 s -> 02,299).
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1000)
    return f'{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}'


def sentence2srt(result, id, srt_path=None):
    """Append one subtitle cue to an SRT file.

    Args:
        result: dict with ``'text'`` and with ``'start'``/``'end'`` offsets
            in seconds.
        id: cue number written on the cue's first line. The name shadows the
            builtin but is kept for caller compatibility.
        srt_path: file to append to. Defaults to
            ``<args.output_path>/<now_time>/subtitle/<file_name>.srt``, built
            from the module-level globals.
    """
    if srt_path is None:
        srt_path = os.path.join(args.output_path, now_time, 'subtitle', f'{file_name}.srt')
    cue = (
        f"{id}\n"
        f"{_format_srt_timestamp(result['start'])} --> {_format_srt_timestamp(result['end'])}\n"
        f"{result['text']}\n\n"
    )
    with open(srt_path, 'a', encoding='utf-8') as f:
        f.write(cue)


def srt_added_to_video(src):
    """Mux this file's generated SRT into *src* as a soft subtitle track.

    Reads ``<args.output_path>/<now_time>/subtitle/<file_name>.srt`` and
    writes ``<args.output_path>/<now_time>/video/<file_name>.mp4``. Audio and
    video are stream-copied (``-c copy``) and the subtitle is converted to
    ``mov_text``. Does nothing if the SRT file does not exist.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero code.
    """
    srt_path = os.path.join(args.output_path, now_time, 'subtitle', f'{file_name}.srt')
    if not os.path.exists(srt_path):
        return
    dst = os.path.join(args.output_path, now_time, 'video', f'{file_name}.mp4')
    # An argument list avoids shell-style quoting issues and works on POSIX.
    command = ['ffmpeg', '-i', src, '-i', srt_path, '-c', 'copy', '-c:s', 'mov_text', '-y', dst]
    subprocess.run(command, capture_output=True, check=True)
    # check=True raises on failure, so reaching here means success.
    print("字幕添加至视频成功")


def create_blank_video(src):
    """Create a black video as long as the audio in *src*, then merge them.

    Probes *src* with ffmpeg to find its first audio stream's duration. It
    then renders a 320x180, 25 fps black clip of that length to
    ``tmp/.mp4`` and calls ``merge_audio_and_blank_video(src)``.

    Raises:
        ValueError: if *src* has no audio stream or no duration is reported.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero code.
    """
    probe = ffmpeg.probe(src)
    audio = next((s for s in probe['streams'] if s.get('codec_type') == 'audio'), None)
    if audio is None:
        raise ValueError(f'no audio stream found in {src}')
    # Some inputs may not report a per-stream duration, so fall back to the
    # container-level duration.
    duration = audio.get('duration') or probe.get('format', {}).get('duration')
    if duration is None:
        raise ValueError(f'could not determine duration of {src}')
    duration = float(duration)
    print(f"duration:{duration}")

    # ffmpeg will not create the output folder, so make sure tmp/ exists.
    os.makedirs('tmp', exist_ok=True)
    command = ['ffmpeg', '-f', 'lavfi',
               '-i', f'color=s=320x180:r=25:color=black:duration={duration}',
               '-y', os.path.join('tmp', '.mp4')]
    subprocess.run(command, capture_output=True, check=True)
    print("空白视频创建成功")
    merge_audio_and_blank_video(src)


def merge_audio_and_blank_video(src):
    """Combine the blank ``tmp/.mp4`` video with the audio of *src*.

    Writes ``tmp/output.mp4``, stream-copying both tracks (``-c copy``).
    The original path literal ``"tmp\\output.mp4"`` relied on the invalid
    escape ``\\o``; ``os.path.join`` replaces it.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero code.
    """
    # Map streams explicitly: take the video from the blank clip and the audio
    # from src. This stops ffmpeg's automatic selection from picking some
    # other video stream in src (e.g. embedded cover art).
    command = ['ffmpeg',
               '-i', os.path.join('tmp', '.mp4'),
               '-i', src,
               '-map', '0:v:0', '-map', '1:a:0',
               '-c', 'copy', '-y', os.path.join('tmp', 'output.mp4')]
    subprocess.run(command, capture_output=True, check=True)
    print("音频和空白视频合并成功")


# Script entry: set up the run folder and model, then subtitle every file in
# the input folder. `now_time`, `model` and `file_name` are module globals
# read by the functions above.
now_time, model = environment_initializing()
# Audio-only inputs get a blank video first so the subtitle can be muxed.
AUDIO_EXTENSIONS = ('.mp3', '.aac')
for index, entry in enumerate(sorted(os.listdir(args.input_path)), start=1):
    src = os.path.join(args.input_path, entry)
    if not os.path.isfile(src):
        continue  # skip sub-folders and other non-files
    print('-' * 8 + f'处理第{index}个视频' + '-' * 8)
    print(entry)
    t1 = time.perf_counter()
    # splitext keeps names such as ".hidden" or "a.b.mp4" -> "a.b" intact.
    file_name = os.path.splitext(entry)[0]
    if entry.lower().endswith(AUDIO_EXTENSIONS):
        create_blank_video(src)
        audio2sentences(src)
        srt_added_to_video(os.path.join('tmp', 'output.mp4'))
    else:
        video2audio(src)
        audio2sentences(os.path.join('tmp', '.mp3'))
        srt_added_to_video(src)

    print(f'耗时:{time.perf_counter() - t1}')
